
Fix DeepSpeed import crash on runtime-only CUDA and improve NVFP4 uncalibrated weight error#896

Open
debo3 wants to merge 1 commit into NVIDIA:main from debo3:fix/nvfp4-export-large-models

Conversation


@debo3 debo3 commented Feb 16, 2026

What does this PR do?

Type of change: Bug fix

Overview: Fixes two crashes that block NVFP4 quantization of large models (>1TB) on production GPU infrastructure.

Bug 1 — DeepSpeed import crashes on runtime-only CUDA systems:

During mtq.quantize(), make_deepspeed_compatible imports DeepSpeed to check for ZeRO-3 compatibility. DeepSpeed's import chain calls nvcc --version to check CUDA compiler compatibility. On runtime-only CUDA installations (NGC containers, cloud GPU instances without the CUDA toolkit), this raises FileNotFoundError. The existing except ImportError doesn't catch it, so quantization crashes before calibration even starts — even though the user isn't using DeepSpeed at all.

Fix: Broaden the exception handler to also catch FileNotFoundError and RuntimeError.
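A minimal sketch of the broadened guard described above. The helper name here is illustrative; in the actual change the `try`/`except` lives inside `make_deepspeed_compatible`:

```python
def import_deepspeed_zero_dict():
    """Try to import DeepSpeed's ZeROOrderedDict, tolerating broken environments.

    Returns None when DeepSpeed is absent (ImportError), when its import chain
    shells out to a missing nvcc binary (FileNotFoundError), or when CUDA
    initialization fails during import (RuntimeError).
    """
    try:
        from deepspeed.runtime.zero.parameter_offload import ZeROOrderedDict
    except (ImportError, FileNotFoundError, RuntimeError):
        return None
    return ZeROOrderedDict
```

With this shape, a runtime-only CUDA system simply skips the ZeRO-3 compatibility shim instead of aborting quantization.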

Bug 2 — Opaque assertion when weight quantizers lack _amax:

NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer() uses assert hasattr(weight_quantizer, "_amax") which produces an opaque AssertionError with no guidance on what went wrong. This is hit when accelerate's device_map="auto" offloads layers to disk on large models — the quantizers are inserted but some may not accumulate _amax during calibration. The user loses hours of calibration time to an error that doesn't explain the cause or fix.

Fix: Replace bare assert with ValueError that explains why _amax is missing (disk offloading, insufficient calib_size) and points to _ensure_weight_quantizer_calibrated() as the resolution.
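The replacement check can be sketched as follows; the function name and exact error wording are illustrative, not the literal text in `nvfp4_tensor.py`:

```python
def require_amax(weight_quantizer):
    """Return the quantizer's _amax statistic, raising an actionable error if missing."""
    amax = getattr(weight_quantizer, "_amax", None)
    if amax is None:
        raise ValueError(
            "weight_quantizer has no _amax statistic. This typically happens when "
            "accelerate's device_map='auto' offloads layers to disk, or when "
            "calib_size is too small for every layer to see calibration data. "
            "Run _ensure_weight_quantizer_calibrated() on the module before export, "
            "or recalibrate with more samples."
        )
    return amax
```

Unlike a bare `assert`, this also survives `python -O` (which strips assertions) and tells the user what to do next.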

Note: PR #785 added _ensure_weight_quantizer_calibrated() in quant_utils.py which handles this case at the call site. This PR improves the safety net in nvfp4_tensor.py itself — if _ensure_weight_quantizer_calibrated is bypassed or fails, the user gets a useful error instead of an opaque assert.

Usage

# Before this fix — crashes on systems without nvcc:
# FileNotFoundError: [Errno 2] No such file or directory: '/usr/local/cuda/bin/nvcc'
import modelopt.torch.quantization as mtq
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=calib_loop)

# After this fix — works correctly, DeepSpeed check is silently skipped

Testing

Tested on:

  • DeepSeek V3 671B BF16 model (1.3TB) on 8x NVIDIA B200 GPUs (SM100)
  • Runtime-only CUDA 12.8 environment (no nvcc installed)
  • PyTorch 2.9.1+cu128, ModelOpt 0.41.0
  • Quantization with NVFP4_DEFAULT_CFG, device_map="auto", 1024 calibration samples

Before fix: FileNotFoundError at mtq.quantize() (Bug 1), AssertionError at export_hf_checkpoint() (Bug 2)
After fix: Both operations complete successfully

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes — only broadens exception handling and improves error messages. No API or behavior changes.
  • Did you write any new necessary tests?: No — these are edge cases in exception handling. Could add a unit test that mocks a failed DeepSpeed import.
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No — minor bug fix, not a feature or breaking change.

Additional Information

Related: PR #785 (Fix a nvfp4 weight amax attribute issue during export) added _ensure_weight_quantizer_calibrated() which addresses the _amax issue at the quant_utils.py call site. This PR adds the safety net at the nvfp4_tensor.py level.

Encountered while quantizing fine-tuned DeepSeek V3 671B models on NVIDIA B200 Blackwell GPUs.

Summary by CodeRabbit

  • Bug Fixes
    • Improved error handling for import failures with more specific exception catching and clearer diagnostic messages.
    • Enhanced validation with informative error messaging when required configuration attributes are missing, guiding users on calibration and export prerequisites.

@debo3 debo3 requested a review from a team as a code owner February 16, 2026 21:00
@debo3 debo3 requested a review from ajrasane February 16, 2026 21:00

copy-pr-bot bot commented Feb 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Feb 16, 2026

📝 Walkthrough

This PR expands error handling in quantization utilities: the deepspeed compatibility function now catches additional exception types (FileNotFoundError, RuntimeError), and NVFP4 tensor weight scaling computation replaces assertion with explicit validation and error messaging.

Changes

  • Exception Handling Enhancements (modelopt/torch/quantization/plugins/transformers.py): Expanded exception handling in make_deepspeed_compatible to catch FileNotFoundError and RuntimeError alongside ImportError, with comments documenting potential causes (missing deepspeed package, missing nvcc/toolchain, or initialization failures). Behavior remains a silent no-op on import failure.
  • Validation Logic Improvement (modelopt/torch/quantization/qtensor/nvfp4_tensor.py): Replaced the assertion with a guarded check in get_weights_scaling_factor_2_from_quantizer. The code now explicitly validates that the _amax attribute exists and is not None, raising a ValueError with a descriptive message about calibration and export requirements if validation fails.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4 passed
  • Description Check: Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Docstring Coverage: Passed. Docstring coverage is 100.00%, which meets the required 80.00% threshold.
  • Merge Conflict Detection: Passed. No merge conflicts detected when merging into main.
  • Title Check: Passed. The title accurately and concisely summarizes both main changes: the DeepSpeed import handling improvement and the NVFP4 error messaging enhancement.



@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/transformers.py (1)

29-34: RuntimeError is broad and may silently mask real DeepSpeed initialization issues.

FileNotFoundError for missing nvcc is well-justified. However, catching RuntimeError is very broad: it could silently swallow legitimate DeepSpeed configuration errors (e.g., version mismatches, NCCL failures) and leave users without ZeRO-3 compatibility when they actually need it, with no indication of the problem.

Consider either:

  1. Logging a warning when RuntimeError is caught, so users have visibility.
  2. Narrowing by inspecting the exception message (e.g., checking for nvcc or cuda substrings) before suppressing.
Option 1: Add a warning log
+import logging
+
+logger = logging.getLogger(__name__)
+
 def make_deepspeed_compatible(model: nn.Module):
     """Make the model compatible with DeepSpeed."""
     try:
         from deepspeed.runtime.zero.parameter_offload import ZeROOrderedDict
-    except (ImportError, FileNotFoundError, RuntimeError):
+    except ImportError:
+        return
+    except (FileNotFoundError, RuntimeError) as e:
+        logger.debug("DeepSpeed import failed, skipping compatibility check: %s", e)
         return

@debo3 debo3 changed the title minor fixes Fix DeepSpeed import crash on runtime-only CUDA and improve NVFP4 uncalibrated weight error Feb 16, 2026